{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# An Introduction to SageMaker ObjectToVec model for sequence-sequence embedding\n",
    "\n",
    "## Table of contents\n",
    "\n",
    "1. [Background](#Background)\n",
    "1. [Download datasets](#Download-datasets)\n",
    "1. [Preprocessing](#Preprocessing)\n",
    "1. [Model training and inference](#Model-training-and-inference)\n",
    "1. [Transfer learning with Object2Vec](#Transfer-learning)\n",
    "1. [How to enable the optimal training result](#How-to-enable-the-optimal-training-result)\n",
    "1. [Hyperparameter Tuning (Advanced)](#Hyperparameter-Tuning-(Advanced))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Background\n",
    "\n",
    "*Object2Vec* is a highly customizable multi-purpose algorithm that can learn embeddings of pairs of objects. The embeddings are learned so that they preserve the pairwise **similarities** of the objects.\n",
    "- **Similarity** is user-defined: users need to provide the algorithm with pairs of objects that they define as similar (1) or dissimilar (0); alternatively, users can define similarity in a continuous sense by providing a real-valued similarity score for each object pair.\n",
    "- The learned embeddings can be used to compute nearest neighbors of objects, as well as to visualize natural clusters of related objects in the embedding space. In addition, the embeddings can be used as features of the corresponding objects in downstream supervised tasks such as classification or regression."
   ]
  },
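  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an illustration of the nearest-neighbor use case, here is a minimal sketch (not part of the Object2Vec API) that ranks objects by cosine similarity to a query object. It assumes the learned embeddings have already been retrieved into a NumPy array with one row per object. The helper name `nearest_neighbors` and the toy vectors are hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def nearest_neighbors(embeddings, query_idx, k=3):\n",
    "    # Hypothetical helper: rank all objects by cosine similarity to embeddings[query_idx]\n",
    "    normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)\n",
    "    sims = normed @ normed[query_idx]\n",
    "    order = np.argsort(-sims)\n",
    "    # drop the query itself and keep the k most similar objects\n",
    "    return [int(i) for i in order if i != query_idx][:k]\n",
    "\n",
    "emb = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]])\n",
    "print(nearest_neighbors(emb, 0, k=2))  # [1, 2]\n",
    "```\n",
    "\n",
    "Cosine similarity is used here because it ignores the length of the embedding vectors. Euclidean distance is a common alternative."
   ]
  },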
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using Object2Vec to Encode Sentences into Fixed Length Embeddings "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we will demonstrate how to train *Object2Vec* to encode sequences of varying length into fixed-length embeddings.\n",
    "\n",
    "As a specific example, we will represent each sentence as a sequence of integers and show how to learn an encoder that embeds these sentences into fixed-length vectors. To this end, we need pairs of sentences with labels that indicate their similarity. The Stanford Natural Language Inference data set (https://nlp.stanford.edu/projects/snli/), which consists\n",
    "of pairs of sentences labeled as \"entailment\", \"neutral\" or \"contradiction\", comes close to these requirements, so we use it as the training data set in this notebook.\n",
    "\n",
    "Once the model is trained on this data,\n",
    "the trained encoders can be used to convert any new English sentence into a fixed-length embedding. We will measure the quality of the learned sentence embeddings on new sentences by computing the similarity, in the embedding space, of sentence pairs from the STS'16 dataset (http://alt.qcri.org/semeval2016/task1/), and evaluating these similarities against human-labeled ground-truth ratings."
   ]
  },
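  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The STS evaluation described above can be sketched as follows. The sketch assumes that the embeddings of the two sentences in each STS pair are already available as NumPy arrays. It scores each pair by cosine similarity and reports the Pearson correlation with the human ratings, which is the official metric of the STS shared tasks. The helper name `sts_pearson` and the toy numbers are illustrative only.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def sts_pearson(left_emb, right_emb, gold_scores):\n",
    "    # Hypothetical helper: cosine similarity of each (left, right) embedding pair ...\n",
    "    left = left_emb / np.linalg.norm(left_emb, axis=1, keepdims=True)\n",
    "    right = right_emb / np.linalg.norm(right_emb, axis=1, keepdims=True)\n",
    "    cos_sim = np.sum(left * right, axis=1)\n",
    "    # ... correlated (Pearson) with the human-labeled similarity ratings\n",
    "    return float(np.corrcoef(cos_sim, gold_scores)[0, 1])\n",
    "\n",
    "left = np.array([[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])\n",
    "right = np.array([[1.0, 0.1], [1.0, 0.0], [1.0, 0.0]])\n",
    "gold = [5.0, 3.0, 0.0]\n",
    "print(round(sts_pearson(left, right, gold), 3))\n",
    "```"
   ]
  },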
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img style=\"float:middle\" src=\"image_snli.png\" width=\"480\">"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Before running the notebook\n",
    "- Please use a Python 3 kernel for the notebook\n",
    "- Please make sure you have the `jsonlines` and `nltk` packages installed"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Install jsonlines and nltk (if you have not already done so)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
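    "import sys\n",
    "# install into the environment of the running kernel (avoids sudo and a mismatched pip)\n",
    "!{sys.executable} -m pip install -U nltk jsonlines"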
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Please be aware of the following requirements about acknowledgment, copyright and availability, cited from the [dataset description page](https://nlp.stanford.edu/projects/snli/).\n",
    "> The Stanford Natural Language Inference Corpus by The Stanford NLP Group is licensed under a Creative Commons Attribution-ShareAlike 4.0 International License.\n",
    "Based on a work at http://shannon.cs.illinois.edu/DenotationGraph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import requests\n",
    "import io\n",
    "import numpy as np\n",
    "from zipfile import ZipFile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SNLI_PATH = 'snli_1.0'\n",
    "STS_PATH = 'sts2016-english-with-gs-v1.0'\n",
    "\n",
    "def download_and_extract(url_address):\n",
    "    # Download a zip archive into memory and extract it into the working directory\n",
    "    request = requests.get(url_address)\n",
    "    request.raise_for_status()\n",
    "    with ZipFile(io.BytesIO(request.content)) as zfile:\n",
    "        zfile.extractall()\n",
    "\n",
    "if not os.path.exists(SNLI_PATH):\n",
    "    download_and_extract(\"https://nlp.stanford.edu/projects/snli/snli_1.0.zip\")\n",
    "\n",
    "if not os.path.exists(STS_PATH):\n",
    "    download_and_extract(\"http://alt.qcri.org/semeval2016/task1/data/uploads/sts2016-english-with-gs-v1.0.zip\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3 \n",
    "import sys, os\n",
    "import jsonlines\n",
    "import json\n",
    "from collections import Counter\n",
    "from itertools import chain, islice\n",
    "from nltk.tokenize import TreebankWordTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# constants\n",
    "\n",
    "BOS_SYMBOL = \"<s>\"\n",
    "EOS_SYMBOL = \"</s>\"\n",
    "UNK_SYMBOL = \"<unk>\"\n",
    "PAD_SYMBOL = \"<pad>\"\n",
    "PAD_ID = 0\n",
    "TOKEN_SEPARATOR = \" \"\n",
    "VOCAB_SYMBOLS = [PAD_SYMBOL, UNK_SYMBOL, BOS_SYMBOL, EOS_SYMBOL]\n",
    "\n",
    " \n",
    "LABEL_DICT = {'entailment':0, 'neutral':1, 'contradiction':2}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Utility functions\n",
    "\n",
    "def read_jsonline(fname):\n",
    "    \"\"\"\n",
    "    Reads jsonline files and returns iterator\n",
    "    \"\"\"\n",
    "    with jsonlines.open(fname) as reader:\n",
    "        for line in reader:\n",
    "            yield line\n",
    "\n",
    "def sentence_to_integers(sentence, tokenizer, word_dict):\n",
    "    \"\"\"\n",
    "    Converts a sentence to a list of integer ids using word_dict.\n",
    "    Note: tokens that are not in word_dict are silently dropped.\n",
    "    TODO: Better handling of out-of-vocabulary tokens (e.g. mapping them to <unk>)\n",
    "    \"\"\"\n",
    "    return [word_dict[token] for token in get_tokens(sentence, tokenizer)\n",
    "           if token in word_dict]\n",
    "\n",
    "\n",
    "def get_tokens(line, tokenizer):\n",
    "    \"\"\"\n",
    "    Yields tokens from input string.\n",
    "\n",
    "    :param line: Input string.\n",
    "    :return: Iterator over tokens.\n",
    "    \"\"\"\n",
    "    for token in tokenizer.tokenize(line):\n",
    "        if len(token) > 0:\n",
    "            yield token\n",
    "\n",
    "            \n",
    "def get_tokens_from_snli(input_dict, tokenizer):\n",
    "    iter_list = list()\n",
    "    for sentence_key in ['sentence1', 'sentence2']:\n",
    "        sentence = input_dict[sentence_key]\n",
    "        iter_list.append(get_tokens(sentence, tokenizer))\n",
    "    return chain(iter_list[0], iter_list[1])\n",
    "\n",
    "\n",
    "def get_tokens_from_sts(input_sentence_pair, tokenizer):\n",
    "    iter_list = list()\n",
    "    for s in input_sentence_pair:\n",
    "        iter_list.append(get_tokens(s, tokenizer))\n",
    "    return chain(iter_list[0], iter_list[1])\n",
    "\n",
    "\n",
    "def resolve_snli_label(raw_label):\n",
    "    \"\"\"\n",
    "    Converts raw label to integer\n",
    "    \"\"\"\n",
    "    return LABEL_DICT[raw_label]"
   ]
  },
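  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`sentence_to_integers` silently drops tokens that are missing from `word_dict` (see its TODO). One possible alternative is to map unknown tokens to the id of the reserved `<unk>` symbol. The sketch below assumes that the vocabulary contains `<unk>`, which is only the case when `build_vocab` is called with `use_reserved_symbols=True`. To keep the sketch self-contained, it uses plain whitespace splitting instead of `TreebankWordTokenizer`. The name `sentence_to_integers_with_unk` and the toy vocabulary are hypothetical.\n",
    "\n",
    "```python\n",
    "UNK_SYMBOL = '<unk>'\n",
    "\n",
    "def sentence_to_integers_with_unk(sentence, word_dict):\n",
    "    # Hypothetical variant: unknown tokens map to the <unk> id instead of being dropped\n",
    "    unk_id = word_dict[UNK_SYMBOL]\n",
    "    return [word_dict.get(token, unk_id) for token in sentence.split()]\n",
    "\n",
    "toy_vocab = {'<pad>': 0, '<unk>': 1, '<s>': 2, '</s>': 3, 'a': 4, 'dog': 5, 'runs': 6}\n",
    "print(sentence_to_integers_with_unk('a dog sleeps', toy_vocab))  # [4, 5, 1]\n",
    "```\n",
    "\n",
    "Keeping an explicit `<unk>` id preserves the sentence length, which dropping tokens does not."
   ]
  },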
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Functions to build vocabulary from SNLI corpus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_vocab(data_iter, dataname='snli', num_words=50000, min_count=1, use_reserved_symbols=True, sort=True):\n",
    "    \"\"\"\n",
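    "    Creates a vocabulary mapping from words to ids. The num_words most frequent tokens are kept\n",
    "    (ties are broken reverse-alphabetically). If sort is True, ids are assigned in alphabetical order;\n",
    "    otherwise they follow descending frequency. If use_reserved_symbols is True, the special symbols\n",
    "    (PAD, UNK, BOS, EOS) are prepended so that the padding symbol (PAD) receives id 0.\n",
    "\n",
    "    :param data_iter: Iterator over SNLI records (dicts) or STS sentence pairs (lists of two strings).\n",
    "    :param dataname: 'snli' or 'sts'; selects how tokens are extracted from data_iter.\n",
    "    :param num_words: Maximum number of words in the vocabulary.\n",
    "    :param min_count: Minimum occurrences of words to be included in the vocabulary.\n",
    "    :param use_reserved_symbols: Whether to prepend the special symbols to the vocabulary.\n",
    "    :param sort: Whether to sort the retained words alphabetically before assigning ids.\n",
    "    :return: word-to-id mapping.\n",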
    "    \"\"\"\n",
    "    vocab_symbols_set = set(VOCAB_SYMBOLS)\n",
    "    tokenizer = TreebankWordTokenizer()\n",
    "    if dataname == 'snli':\n",
    "        raw_vocab = Counter(token for line in data_iter for token in get_tokens_from_snli(line, tokenizer)\n",
    "                        if token not in vocab_symbols_set)\n",
    "    elif dataname == 'sts':\n",
    "        raw_vocab = Counter(token for line in data_iter for token in get_tokens_from_sts(line, tokenizer) \n",
    "                            if token not in vocab_symbols_set)\n",
    "    else:\n",
    "        raise ValueError(f'Data name {dataname} is not recognized!')\n",
    "        \n",
    "    print(\"Initial vocabulary: {} types\".format(len(raw_vocab)))\n",
    "\n",
    "    # For words with the same count, they will be ordered reverse alphabetically.\n",
    "    # Not an issue since we only care for consistency\n",
    "    pruned_vocab = sorted(((c, w) for w, c in raw_vocab.items() if c >= min_count), reverse=True)\n",
    "    print(\"Pruned vocabulary: {} types (min frequency {})\".format(len(pruned_vocab), min_count))\n",
    "    \n",
    "    # truncate the vocabulary to fit size num_words (only includes the most frequent ones)\n",
    "    vocab = islice((w for c, w in pruned_vocab), num_words)\n",
    "\n",
    "    if sort:\n",
    "        # sort the vocabulary alphabetically\n",
    "        vocab = sorted(vocab)\n",
    "    if use_reserved_symbols:\n",
    "        vocab = chain(VOCAB_SYMBOLS, vocab)\n",
    "    \n",
    "    word_to_id = {word: idx for idx, word in enumerate(vocab)}\n",
    "\n",
    "    print(\"Final vocabulary: {} types\".format(len(word_to_id)))\n",
    "\n",
    "    if use_reserved_symbols:\n",
    "        # Important: pad symbol becomes index 0\n",
    "        assert word_to_id[PAD_SYMBOL] == PAD_ID\n",
    "    return word_to_id"
   ]
  },
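  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the id assignment in `build_vocab` concrete, the sketch below replays its pruning, truncation and sorting steps on a toy `Counter` (the token counts are made up). With `sort=True`, the `num_words` most frequent tokens are kept, but their ids follow alphabetical order rather than frequency.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "from itertools import islice\n",
    "\n",
    "raw_vocab = Counter({'the': 5, 'dog': 3, 'cat': 3, 'runs': 1})\n",
    "min_count, num_words = 1, 3\n",
    "\n",
    "# Same steps as build_vocab: prune by count, keep the most frequent words, then sort alphabetically\n",
    "pruned_vocab = sorted(((c, w) for w, c in raw_vocab.items() if c >= min_count), reverse=True)\n",
    "vocab = sorted(islice((w for c, w in pruned_vocab), num_words))\n",
    "word_to_id = {word: idx for idx, word in enumerate(vocab)}\n",
    "print(word_to_id)  # {'cat': 0, 'dog': 1, 'the': 2}\n",
    "```\n",
    "\n",
    "With `num_words=2`, the tie between `dog` and `cat` would be resolved in favor of `dog`, because the pruned list is sorted in reverse order."
   ]
  },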
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Functions to convert SNLI data to pairs of sequences of integers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_snli_to_integers(data_iter, word_to_id, dirname=SNLI_PATH, fname_suffix=\"\"):\n",
    "    \"\"\"\n",
    "    Go through an SNLI jsonline iterator record by record, convert both sentences to lists of integers,\n",
    "    and convert the gold labels to integer labels via LABEL_DICT.\n",
    "    Records whose gold label is not in LABEL_DICT (e.g. '-', used when annotators reached no consensus)\n",
    "    are skipped and counted as invalid.\n",
    "    \"\"\"\n",
    "    fname = 'snli-integer-' + fname_suffix + '.jsonl'\n",
    "    path = os.path.join(dirname, fname)\n",
    "    tokenizer = TreebankWordTokenizer()\n",
    "    invalid_count = 0\n",
    "    max_seq_length = 0\n",
    "    with jsonlines.open(path, mode='w') as writer:\n",
    "        for in_dict in data_iter:\n",
    "            out_dict = dict()\n",
    "            rlabel = in_dict['gold_label']\n",
    "            if rlabel in LABEL_DICT:\n",
    "                rsentence1 = in_dict['sentence1']\n",
    "                rsentence2 = in_dict['sentence2']\n",
    "                for idx, sentence in enumerate([rsentence1, rsentence2]):\n",
    "                    s = sentence_to_integers(sentence, tokenizer, word_to_id)\n",
    "                    out_dict[f'in{idx}'] = s\n",
    "                    max_seq_length = max(len(s), max_seq_length)\n",
    "                out_dict['label'] = resolve_snli_label(rlabel)\n",
    "                writer.write(out_dict)\n",
    "            else:\n",
    "                # only records without a valid gold label are counted here\n",
    "                invalid_count += 1\n",
    "    print(f\"There are in total {invalid_count} invalid labels\")\n",
    "    print(f\"The max length of converted sequence is {max_seq_length}\")"
   ]
  },
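  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each line written by `convert_snli_to_integers` is a JSON object. It holds the integer sequences of the two sentences under `in0` and `in1` and the class id under `label`. The sketch below writes and reads back one such record with the standard `json` module; the `jsonlines` writer used above produces the same one-object-per-line layout. The token ids are made up.\n",
    "\n",
    "```python\n",
    "import io\n",
    "import json\n",
    "\n",
    "LABEL_DICT = {'entailment': 0, 'neutral': 1, 'contradiction': 2}\n",
    "\n",
    "# One made-up converted SNLI pair: premise ids, hypothesis ids, integer label\n",
    "record = {'in0': [12, 7, 301], 'in1': [12, 88], 'label': LABEL_DICT['contradiction']}\n",
    "\n",
    "buffer = io.StringIO()\n",
    "print(json.dumps(record), file=buffer)  # JSON Lines: one object per line\n",
    "line = buffer.getvalue().splitlines()[0]\n",
    "print(line)  # {\"in0\": [12, 7, 301], \"in1\": [12, 88], \"label\": 2}\n",
    "```"
   ]
  },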
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate vocabulary from SNLI data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_snli_full_vocab(dirname=SNLI_PATH, force=True):\n",
    "    vocab_path = os.path.join(dirname, 'snli-vocab.json')\n",
    "    if not os.path.exists(vocab_path) or force:\n",
    "        data_iter_list = list()\n",
    "        for fname_suffix in [\"train\", \"test\", \"dev\"]:\n",
    "            fname = \"snli_1.0_\" + fname_suffix + \".jsonl\"\n",
    "            data_iter_list.append(read_jsonline(os.path.join(dirname, fname)))\n",
    "        data_iter = chain(data_iter_list[0], data_iter_list[1], data_iter_list[2])\n",
    "        with open(vocab_path, \"w\") as write_file:\n",
    "            word_to_id = build_vocab(data_iter, num_words=50000, min_count=1, use_reserved_symbols=False, sort=True)\n",
    "            json.dump(word_to_id, write_file)\n",
    "\n",
    "make_snli_full_vocab(force=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate tokenized SNLI data as sequences of integers\n",
    "- We use the SNLI vocabulary as a lookup dictionary to convert SNLI sentence pairs into sequences of integers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_snli_data(dirname=SNLI_PATH, vocab_file='snli-vocab.json', outfile_suffix=\"\", force=True):\n",
    "    for fname_suffix in [\"train\", \"test\", \"validation\"]:\n",
    "        outpath = os.path.join(dirname, f'snli-integer-{fname_suffix}-{outfile_suffix}.jsonl')\n",
    "        if not os.path.exists(outpath) or force:\n",
    "            if fname_suffix=='validation':\n",
    "                inpath = os.path.join(dirname, f'snli_1.0_dev.jsonl')\n",
    "            else:\n",
    "                inpath = os.path.join(dirname, f'snli_1.0_{fname_suffix}.jsonl')\n",
    "            data_iter = read_jsonline(inpath)\n",
    "            vocab_path = os.path.join(dirname, vocab_file)\n",
    "            with open(vocab_path, \"r\") as f:\n",
    "                word_to_id = json.load(f)   \n",
    "            convert_snli_to_integers(data_iter, word_to_id, dirname=dirname, \n",
    "                                     fname_suffix=f'{fname_suffix}-{outfile_suffix}')\n",
    "\n",
    "            \n",
    "make_snli_data(force=False)"
   ]
  },
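  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After the conversion, it can be useful to sanity-check the generated files. For example, you can confirm that the longest sequence fits within the `enc0_max_seq_len`/`enc1_max_seq_len` values used for training and that the label ids are as expected. The helper `summarize_records` below is hypothetical. It is demonstrated on in-memory records, but it could also be fed `read_jsonline(os.path.join(SNLI_PATH, 'snli-integer-train-.jsonl'))`.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "def summarize_records(records):\n",
    "    # Hypothetical helper: longest input sequence and label distribution\n",
    "    max_len = 0\n",
    "    label_counts = Counter()\n",
    "    for rec in records:\n",
    "        max_len = max(max_len, len(rec['in0']), len(rec['in1']))\n",
    "        label_counts[rec['label']] += 1\n",
    "    return max_len, label_counts\n",
    "\n",
    "toy = [{'in0': [1, 2, 3], 'in1': [4], 'label': 0},\n",
    "       {'in0': [5], 'in1': [6, 7], 'label': 2}]\n",
    "max_len, label_counts = summarize_records(toy)\n",
    "print(max_len, dict(label_counts))  # 3 {0: 1, 2: 1}\n",
    "```"
   ]
  },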
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model training and inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_vocab_size(vocab_path):\n",
    "    with open(vocab_path) as f:\n",
    "        word_to_id = json.load(f)\n",
    "        print(f\"There are {len(word_to_id.keys())} words in vocabulary {vocab_path}\")\n",
    "    \n",
    "\n",
    "vocab_path = os.path.join(SNLI_PATH, 'snli-vocab.json')\n",
    "print_vocab_size(vocab_path)   "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the runs in this notebook, we will use the Hierarchical CNN architecture to encode each of the sentences into fixed length embeddings. Some of the other hyperparameters are shown below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Define default hyperparameters\n",
    "DEFAULT_HP = {\n",
    "  \"enc_dim\": 4096,\n",
    "  \"mlp_dim\": 512,\n",
    "  \"mlp_activation\": \"linear\",\n",
    "  \"mlp_layers\": 2,\n",
    "  \"output_layer\" : \"softmax\",\n",
    "\n",
    "  \"optimizer\" : \"adam\",\n",
    "  \"learning_rate\" : 0.0004,\n",
    "  \"mini_batch_size\": 32,\n",
    "  \"epochs\" : 20,\n",
    "  \"bucket_width\": 0,\n",
    "\n",
    "  \"early_stopping_tolerance\" : 0.01,\n",
    "  \"early_stopping_patience\" : 3,\n",
    "\n",
    "  \"dropout\": 0,\n",
    "  \"weight_decay\": 0,\n",
    "\n",
    "  \"enc0_max_seq_len\": 82,\n",
    "  \"enc1_max_seq_len\": 82,\n",
    "\n",
    "  \"enc0_network\": \"hcnn\",\n",
    "  \"enc1_network\": \"enc0\",\n",
    "\n",
    "  \"enc0_token_embedding_dim\": 300,\n",
    "  \"enc0_layers\": \"auto\",\n",
    "  \"enc0_cnn_filter_width\": 3,\n",
    "\n",
    "  \"enc1_token_embedding_dim\": 300,\n",
    "  \"enc1_layers\": \"auto\",\n",
    "  \"enc1_cnn_filter_width\": 3,\n",
    "\n",
    "  \"enc0_vocab_file\" : \"\",\n",
    "  \"enc1_vocab_file\" : \"\",\n",
    "\n",
    "  \"enc0_vocab_size\" : 43533,\n",
    "  \"enc1_vocab_size\" : 43533,\n",
    "\n",
    "  \"num_classes\": 3,\n",
    "\n",
    "  \"_num_gpus\" : \"auto\",\n",
    "  \"_num_kv_servers\" : \"auto\",\n",
    "  \"_kvstore\" : \"device\"\n",
    "}"
   ]
  },
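  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The vocabulary sizes above are hard-coded to 43533. Below is a minimal sketch of deriving them from the vocabulary file instead, so that they stay in sync if the vocabulary is rebuilt. It assumes that the vocabulary size equals the number of entries in the word-to-id mapping, which holds for the vocabulary built here, whose ids run from 0 to N-1. The helper `with_vocab_size` is hypothetical.\n",
    "\n",
    "```python\n",
    "import json\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "def with_vocab_size(hyperparameters, vocab_path):\n",
    "    # Hypothetical helper: copy the hyperparameters and set both encoder vocab sizes\n",
    "    with open(vocab_path) as f:\n",
    "        vocab_size = len(json.load(f))\n",
    "    updated = dict(hyperparameters)\n",
    "    updated['enc0_vocab_size'] = vocab_size\n",
    "    updated['enc1_vocab_size'] = vocab_size\n",
    "    return updated\n",
    "\n",
    "# Demonstration with a tiny temporary vocabulary file\n",
    "with tempfile.TemporaryDirectory() as tmp:\n",
    "    path = os.path.join(tmp, 'toy-vocab.json')\n",
    "    with open(path, 'w') as f:\n",
    "        json.dump({'a': 0, 'cat': 1, 'dog': 2}, f)\n",
    "    hp = with_vocab_size({'enc_dim': 4096}, path)\n",
    "print(hp)  # {'enc_dim': 4096, 'enc0_vocab_size': 3, 'enc1_vocab_size': 3}\n",
    "```\n",
    "\n",
    "In this notebook, the call would be `with_vocab_size(DEFAULT_HP, vocab_path)`, which leaves `DEFAULT_HP` itself unchanged."
   ]
  },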
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define input data channel and output path in S3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Input data bucket and prefix\n",
    "\n",
    "bucket = '<your bucket name>' # Customize your bucket\n",
    "prefix = 'object2vec/input/' \n",
    "input_path = os.path.join('s3://', bucket, prefix)\n",
    "print(f\"Data path for training is {input_path}\")\n",
    "## Output path\n",
    "output_prefix = 'object2vec/output/'\n",
    "output_bucket = bucket\n",
    "output_path = os.path.join('s3://', output_bucket, output_prefix)\n",
    "print(f\"Trained model will be saved at {output_path}\")"
   ]
  },
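  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`os.path.join` produces the expected S3 URIs above only because the notebook runs on a POSIX system; on Windows it would insert backslashes. A platform-independent alternative is sketched below using `posixpath.join`, which always uses forward slashes. The helper `s3_uri` is hypothetical and the bucket name is a placeholder.\n",
    "\n",
    "```python\n",
    "import posixpath\n",
    "\n",
    "def s3_uri(bucket, *parts):\n",
    "    # Join key components with forward slashes regardless of the local OS\n",
    "    return 's3://' + posixpath.join(bucket, *parts)\n",
    "\n",
    "bucket = 'my-example-bucket'  # placeholder bucket name\n",
    "print(s3_uri(bucket, 'object2vec/input/'))  # s3://my-example-bucket/object2vec/input/\n",
    "print(s3_uri(bucket, 'object2vec/input/', 'train/snli-integer-train.jsonl'))  # s3://my-example-bucket/object2vec/input/train/snli-integer-train.jsonl\n",
    "```"
   ]
  },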
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialize SageMaker estimator\n",
    "- Get the IAM role and the Object2Vec algorithm image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "from sagemaker import get_execution_role\n",
    "\n",
    "sess = sagemaker.Session()\n",
    "\n",
    "\n",
    "role = get_execution_role()\n",
    "print(role)\n",
    "\n",
    "## Get docker image of ObjectToVec algorithm\n",
    "from sagemaker.amazon.amazon_estimator import get_image_uri\n",
    "container = get_image_uri(boto3.Session().region_name, 'object2vec')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "from sagemaker.session import s3_input\n",
    "\n",
    "\n",
    "def set_training_environment(bucket, prefix, base_hyperparameters=DEFAULT_HP,\n",
    "                             is_quick_run=True, is_pretrain=False, use_all_vocab=False):\n",
    "    \n",
    "    input_channels = {}\n",
    "    s3_client = boto3.client('s3')\n",
    "    for split in ['train', 'validation']:\n",
    "        if is_pretrain:\n",
    "            fname_in = f'all_vocab_datasets/snli-integer-{split}-pretrain.jsonl'\n",
    "            fname_out = f'{split}/snli-integer-{split}-pretrain.jsonl'\n",
    "        else:\n",
    "            fname_in = os.path.join(SNLI_PATH, f'snli-integer-{split}-.jsonl')\n",
    "            fname_out = f'{split}/snli-integer-{split}.jsonl'\n",
    "        \n",
    "        s3_key = os.path.join(prefix, fname_out)\n",
    "        s3_client.upload_file(fname_in, bucket, s3_key)\n",
    "        # build the channel URI from the bucket/prefix arguments rather than the global input_path\n",
    "        s3_uri = f's3://{bucket}/{s3_key}'\n",
    "        input_channels[split] = s3_input(s3_uri, \n",
    "                                 distribution='ShardedByS3Key', \n",
    "                                 content_type='application/jsonlines')\n",
    "    \n",
    "        print('Uploaded {} data to {}'.format(split, s3_uri))\n",
    "    \n",
    "    hyperparameters = base_hyperparameters.copy()\n",
    "    \n",
    "    if use_all_vocab:\n",
    "        hyperparameters['enc0_vocab_file'] = 'all_vocab.json'\n",
    "        hyperparameters['enc1_vocab_file'] = 'all_vocab.json'\n",
    "        hyperparameters['enc0_vocab_size'] = 43662\n",
    "        hyperparameters['enc1_vocab_size'] = 43662\n",
    "\n",
    "    if is_pretrain:\n",
    "        ## set up auxiliary channel\n",
    "        aux_path = os.path.join(prefix, \"auxiliary\")\n",
    "        # upload auxiliary files\n",
    "        assert os.path.exists(\"GloVe/glove.840B-trim.txt\"), \"Pretrained embedding does not exist!\"\n",
    "        s3_client.upload_file(\"GloVe/glove.840B-trim.txt\", bucket, os.path.join(aux_path, 'glove.840B-trim.txt'))\n",
    "        if use_all_vocab:\n",
    "            s3_client.upload_file('all_vocab_datasets/all_vocab.json', \n",
    "                                  bucket, os.path.join(aux_path, 'all_vocab.json'))\n",
    "        else:\n",
    "            s3_client.upload_file(\"snli_1.0/snli-vocab.json\", \n",
    "                                  bucket, os.path.join(aux_path, \"snli-vocab.json\"))\n",
    "\n",
    "        input_channels['auxiliary'] = s3_input('s3://' + bucket + '/' + aux_path, \n",
    "                                     distribution='FullyReplicated', content_type='application/json')\n",
    "        \n",
    "        print('Uploaded auxiliary data for initializing with pretrain-embedding to {}'.format(aux_path))\n",
    "        \n",
    "        # add pretrained_embedding_file name to hyperparameters\n",
    "        for idx in [0, 1]:\n",
    "            hyperparameters[f'enc{idx}_pretrained_embedding_file'] = 'glove.840B-trim.txt'\n",
    "\n",
    "    if is_quick_run:\n",
    "        hyperparameters['mini_batch_size'] = 8192\n",
    "        hyperparameters['enc_dim'] = 16\n",
    "        hyperparameters['epochs'] = 2\n",
    "    else:\n",
    "        hyperparameters['mini_batch_size'] = 256\n",
    "        hyperparameters['enc_dim'] = 8192\n",
    "        hyperparameters['epochs'] = 20\n",
    "    return hyperparameters, input_channels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train without using pretrained embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## get estimator\n",
    "regressor = sagemaker.estimator.Estimator(container,\n",
    "                                          role, \n",
    "                                          train_instance_count=1, \n",
    "                                          train_instance_type='ml.p2.xlarge',\n",
    "                                          output_path=output_path,\n",
    "                                          sagemaker_session=sess)\n",
    "\n",
    "\n",
    "## set up training environment\n",
    "# - For a good training result, set is_quick_run to False\n",
    "# - To test-run the algorithm quickly, set is_quick_run to True\n",
    "hyperparameters, input_channels = set_training_environment(bucket, prefix, \n",
    "                                                           is_quick_run=True, \n",
    "                                                           is_pretrain=False, use_all_vocab=False)\n",
    "\n",
    "regressor.set_hyperparameters(**hyperparameters)\n",
    "regressor.hyperparameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "regressor.fit(input_channels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot evaluation metrics for training job\n",
    "\n",
    "Evaluation metrics for the completed training job are available in CloudWatch. We can pull the cross entropy metric of the validation data set and plot it to see the performance of the model over time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "from sagemaker.analytics import TrainingJobAnalytics\n",
    "\n",
    "latest_job_name = regressor.latest_training_job.job_name\n",
    "metric_name = 'validation:cross_entropy'\n",
    "\n",
    "metrics_dataframe = TrainingJobAnalytics(training_job_name=latest_job_name, metric_names=[metric_name]).dataframe()\n",
    "ax = metrics_dataframe.plot(kind='line', figsize=(12,5), x='timestamp', y='value', style='b.', legend=False)\n",
    "ax.set_ylabel(metric_name);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Deploy trained algorithm and set input-output configuration for inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.predictor import json_serializer, json_deserializer\n",
    "\n",
    "# deploy the model to a real-time endpoint; since no endpoint_name is passed, the SDK chooses the name\n",
    "predictor1 = regressor.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define encode-decode format for inference data\n",
    "predictor1.content_type = 'application/json'\n",
    "predictor1.serializer = json_serializer\n",
    "predictor1.deserializer = json_deserializer"
   ]
  },
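  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the JSON serializer, each request body is a JSON object whose `instances` list holds one `{'in0': [...], 'in1': [...]}` pair per example. For this classification model, each element of the response's `predictions` list carries a `scores` vector with one entry per class. The sketch below builds such a payload and parses a mocked response with the standard `json` module. The token ids and scores are made up, and no endpoint is called.\n",
    "\n",
    "```python\n",
    "import json\n",
    "import numpy as np\n",
    "\n",
    "# Request body in the format used by the evaluation loop below (made-up token ids)\n",
    "payload = {'instances': [{'in0': [12, 7, 301], 'in1': [12, 88]},\n",
    "                         {'in0': [5, 9], 'in1': [41, 2, 17]}]}\n",
    "body = json.dumps(payload)  # what the JSON serializer sends\n",
    "\n",
    "# Mocked response body: one score vector per instance, indexed like LABEL_DICT\n",
    "mock_response = json.dumps({'predictions': [{'scores': [0.1, 0.2, 0.7]},\n",
    "                                            {'scores': [0.8, 0.1, 0.1]}]})\n",
    "predictions = json.loads(mock_response)\n",
    "predicted_labels = [int(np.argmax(p['scores'])) for p in predictions['predictions']]\n",
    "print(predicted_labels)  # [2, 0]\n",
    "```"
   ]
  },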
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Invoke endpoint and do inference with trained model\n",
    "- The trained model is now served by the endpoint behind `predictor1`. The helper below computes the classification accuracy of the endpoint's predictions against the ground-truth labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_prediction_accuracy(predictions, labels):\n",
    "    # predictions: endpoint response of the form {'predictions': [{'scores': [...]}, ...]}\n",
    "    # labels: list of {'label': int}, aligned with the predictions\n",
    "    n_errors = 0\n",
    "    for score, label in zip(predictions['predictions'], labels):\n",
    "        plabel = np.argmax(score['scores'])\n",
    "        n_errors += int(plabel != label['label'])\n",
    "    return 1 - n_errors / len(labels)"
   ]
  },
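  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick offline check of the accuracy computation is sketched below. It uses a mocked endpoint response instead of a live call. The function body mirrors `calc_prediction_accuracy` so that the example is self-contained, and the scores and labels are made up.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def calc_prediction_accuracy(predictions, labels):\n",
    "    # Fraction of examples whose highest-scoring class matches the label\n",
    "    n_errors = 0\n",
    "    for score, label in zip(predictions['predictions'], labels):\n",
    "        n_errors += int(np.argmax(score['scores']) != label['label'])\n",
    "    return 1 - n_errors / len(labels)\n",
    "\n",
    "mock_predictions = {'predictions': [{'scores': [0.1, 0.2, 0.7]},\n",
    "                                    {'scores': [0.8, 0.1, 0.1]},\n",
    "                                    {'scores': [0.3, 0.4, 0.3]},\n",
    "                                    {'scores': [0.6, 0.3, 0.1]}]}\n",
    "mock_labels = [{'label': 2}, {'label': 0}, {'label': 1}, {'label': 2}]\n",
    "print(calc_prediction_accuracy(mock_predictions, mock_labels))  # 0.75\n",
    "```"
   ]
  },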
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Send mini-batches of SNLI test data to the endpoint and evaluate our model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import sagemaker\n",
    "from sagemaker.predictor import json_serializer, json_deserializer\n",
    "\n",
    "# load SNLI test data\n",
    "snli_test_path = os.path.join(SNLI_PATH, 'snli-integer-test-.jsonl')\n",
    "test_data_content = list()\n",
    "test_label = list()\n",
    "\n",
    "for line in read_jsonline(snli_test_path):\n",
    "    test_data_content.append({'in0':line['in0'], 'in1':line['in1']})\n",
    "    test_label.append({'label': line['label']})\n",
    "\n",
    "print(\"Evaluating test results on SNLI without pre-trained embedding...\")\n",
    "\n",
    "\n",
    "batch_size = 100\n",
    "n_test = len(test_label)\n",
    "n_batches = math.ceil(n_test / float(batch_size))\n",
    "start = 0\n",
    "agg_acc = 0\n",
    "for idx in range(n_batches):\n",
    "    if idx % 10 == 0:\n",
    "        print(f\"Evaluating the {idx+1}-th batch\")\n",
    "    end = min(start + batch_size, n_test)\n",
    "    payload = {'instances': test_data_content[start:end]}\n",
    "    acc = calc_prediction_accuracy(predictor1.predict(payload), test_label[start:end])\n",
    "    # weight each batch accuracy by the number of examples in the batch (end - start)\n",
    "    agg_acc += acc * (end - start)\n",
    "    start = end\n",
    "print(f\"The test accuracy is {agg_acc/n_test}\")"
   ]
  },
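  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The aggregate test accuracy is a weighted average of the per-batch accuracies, with each batch weighted by the number of examples it contains. This matters for the last batch, which is usually smaller than `batch_size`. The toy sketch below checks that this weighting reproduces the accuracy computed over all examples at once. The correctness flags are made up.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "correct = [1, 0, 1, 1, 1, 0, 1]  # made-up per-example correctness flags\n",
    "batch_size = 3\n",
    "n_test = len(correct)\n",
    "\n",
    "agg_acc = 0.0\n",
    "for idx in range(math.ceil(n_test / batch_size)):\n",
    "    start = idx * batch_size\n",
    "    end = min(start + batch_size, n_test)\n",
    "    batch_acc = sum(correct[start:end]) / (end - start)\n",
    "    agg_acc += batch_acc * (end - start)  # weight by the batch size\n",
    "\n",
    "print(agg_acc / n_test, sum(correct) / n_test)\n",
    "```"
   ]
  },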
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Transfer learning\n",
    "- We evaluate the trained model directly on the STS16 **question-question** task\n",
    "- See SemEval-2016 Task 1 paper (http://www.aclweb.org/anthology/S16-1081) for an explanation of the evaluation method and benchmarking results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cells below show how to combine the vocabularies of STS and SNLI, and how to obtain GloVe pretrained embeddings."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Functions to generate STS evaluation set (from sts-2016-test set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loadSTSFile(fpath=STS_PATH, datasets=['question-question']):\n",
    "    data = {}\n",
    "    for dataset in datasets:\n",
    "        sent1 = []\n",
    "        sent2 = []\n",
    "        for line in io.open(fpath + f'/STS2016.input.{dataset}.txt',\n",
    "                        encoding='utf8').read().splitlines():\n",
    "            splitted = line.split(\"\\t\")                \n",
    "            sent1.append(splitted[0])\n",
    "            sent2.append(splitted[1])\n",
    "        \n",
    "        raw_scores = np.array([x for x in\n",
    "                            io.open(fpath + f'/STS2016.gs.{dataset}.txt',\n",
    "                            encoding='utf8').read().splitlines()])\n",
    "        \n",
    "        not_empty_idx = raw_scores != ''\n",
    "\n",
    "        gs_scores = [float(x) for x in raw_scores[not_empty_idx]]\n",
    "        sent1 = np.array(sent1)[not_empty_idx]\n",
    "        sent2 = np.array(sent2)[not_empty_idx]\n",
    "\n",
    "        data[dataset] = (sent1, sent2, gs_scores)\n",
    "    \n",
    "    return data\n",
    "\n",
    "def get_sts_data_iterator(fpath=STS_PATH, datasets=['question-question']):\n",
    "    data = loadSTSFile(fpath, datasets)\n",
    "    for dataset in datasets:\n",
    "        sent1, sent2, _ = data[dataset]\n",
    "        for s1, s2 in zip(sent1, sent2):\n",
    "            yield [s1, s2]\n",
    "\n",
    "## preprocessing unit for STS test data\n",
    "\n",
    "def convert_single_sts_to_integers(s1, s2, gs_label, tokenizer, word_dict):\n",
    "    converted = []\n",
    "    for s in [s1, s2]:\n",
    "        converted.append(sentence_to_integers(s, tokenizer, word_dict))\n",
    "    converted.append(gs_label)\n",
    "    return converted\n",
    "\n",
    "\n",
    "def convert_sts_to_integers(sent1, sent2, gs_labels, tokenizer, word_dict):\n",
    "    for s1, s2, gs in zip(sent1, sent2, gs_labels):\n",
    "        yield convert_single_sts_to_integers(s1, s2, gs, tokenizer, word_dict)\n",
    "\n",
    "        \n",
    "\n",
    "def make_sts_data(fpath=STS_PATH, vocab_path_prefix=SNLI_PATH, \n",
    "                  vocab_name='snli-vocab.json', \n",
    "                  dataset='question-question'):\n",
    "    \"\"\"\n",
    "    prepare test data; example: test_data['left'] = [{'in1':[1,2,3]}, {'in1':[2,10]}, ...]\n",
    "    \"\"\"\n",
    "    test_data = {'left':[], 'right':[]}\n",
    "    test_label = list()\n",
    "    tokenizer = TreebankWordTokenizer()\n",
    "    vocab_path = os.path.join(vocab_path_prefix, vocab_name)\n",
    "    with open(vocab_path) as f:\n",
    "        word_dict = json.load(f)\n",
    "    data = loadSTSFile(fpath=fpath, datasets=[dataset])\n",
    "    for s1, s2, gs in convert_sts_to_integers(*data[dataset], tokenizer, word_dict):\n",
    "        test_data['left'].append({'in1': s1})\n",
    "        test_data['right'].append({'in1': s2})\n",
    "        test_label.append(gs)\n",
    "    return test_data, test_label"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that in `make_sts_data`, we pass both inputs (`s1` and `s2`) to a single encoder, in this case `in1`. This ensures that both inputs are mapped by the same encoding function (we found empirically that this is crucial for competitive embedding performance)."
   ]
  },
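  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the request layout concrete, here is a minimal, self-contained sketch of the payloads that `make_sts_data` and `wrap_sts_test_data_for_eval` build. It does not reuse the notebook's helpers: the toy vocabulary `toy_vocab`, the helper `to_ids`, the whitespace tokenizer, and mapping unknown words to id 0 are all assumptions for illustration. The point is that the left and right sentences are sent in two separate payloads, but both under the key `in1`, so the same encoder embeds them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical toy vocabulary; the notebook loads the real one from a JSON file\n",
    "toy_vocab = {'how': 1, 'do': 2, 'i': 3, 'learn': 4, 'python': 5, 'can': 6}\n",
    "TOY_UNK_ID = 0  # assumption: unknown words map to id 0 in this sketch only\n",
    "\n",
    "\n",
    "def to_ids(sentence, vocab):\n",
    "    # Lower-cased whitespace split stands in for TreebankWordTokenizer here\n",
    "    return [vocab.get(tok, TOY_UNK_ID) for tok in sentence.lower().split()]\n",
    "\n",
    "\n",
    "toy_pairs = [('How do I learn Python', 'How can I learn Python')]\n",
    "\n",
    "# Both sides of every pair go to the same encoder input, 'in1',\n",
    "# but in two separate request payloads (left and right)\n",
    "toy_left = {'instances': [{'in1': to_ids(s1, toy_vocab)} for s1, _ in toy_pairs]}\n",
    "toy_right = {'instances': [{'in1': to_ids(s2, toy_vocab)} for _, s2 in toy_pairs]}\n",
    "print(toy_left)\n",
    "print(toy_right)"
   ]
  },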
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build vocabulary using STS corpus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_sts_full_vocab(dirname=STS_PATH, datasets=['question-question'], force=True):\n",
    "    vocab_path = os.path.join(dirname, 'sts-vocab.json')\n",
    "    if not os.path.exists(vocab_path) or force:\n",
    "        data_iter = get_sts_data_iterator(dirname, datasets)\n",
    "        with open(vocab_path, \"w\") as write_file:\n",
    "            word_to_id = build_vocab(data_iter, dataname='sts', \n",
    "                                     num_words=50000, min_count=1, \n",
    "                                     use_reserved_symbols=False, sort=True)\n",
    "            \n",
    "            json.dump(word_to_id, write_file)\n",
    "\n",
    "make_sts_full_vocab(force=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define functions for evaluating the embeddings on the STS16 question-question task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import pearsonr, spearmanr\n",
    "import math\n",
    "\n",
    "\n",
    "def wrap_sts_test_data_for_eval(fpath=STS_PATH, vocab_path_prefix=\".\", \n",
    "                       vocab_name='all_vocab.json', dataset='question-question'):\n",
    "    \"\"\"\n",
    "    Prepare data for evaluation\n",
    "    \"\"\"\n",
    "    test_data, test_label = make_sts_data(fpath, vocab_path_prefix, vocab_name, dataset)\n",
    "    input1 = {\"instances\" : test_data['left']}\n",
    "    input2 = {\"instances\" : test_data['right']}\n",
    "    return [input1, input2, test_label]\n",
    "\n",
    "def get_cosine_similarity(vec1, vec2):\n",
    "    assert len(vec1)==len(vec2), \"Vector dimension mismatch!\"\n",
    "    norm1 = 0\n",
    "    norm2 = 0\n",
    "    inner_product = 0\n",
    "    for v1, v2 in zip(vec1, vec2):\n",
    "        norm1 += v1 ** 2\n",
    "        norm2 += v2 ** 2\n",
    "        inner_product += v1 * v2\n",
    "    return inner_product / math.sqrt(norm1 * norm2)\n",
    "\n",
    "def eval_corr(predictor, eval_data):\n",
    "    \"\"\"\n",
    "    Evaluate the Pearson and Spearman correlation between the cosine similarity\n",
    "    of the algorithm's embeddings and the gold-standard scores.\n",
    "\n",
    "    predictor: deployed SageMaker model endpoint\n",
    "    eval_data: a list of [input1, input2, gs_scores]\n",
    "    \"\"\"\n",
    "    sys_scores = []\n",
    "    input1, input2, gs_scores = eval_data  # as returned by wrap_sts_test_data_for_eval\n",
    "    embeddings = []\n",
    "    for data in [input1, input2]:\n",
    "        prediction = predictor.predict(data)\n",
    "        embeddings.append(prediction['predictions'])\n",
    "    \n",
    "    for emb_pair in zip(embeddings[0], embeddings[1]):\n",
    "        emb1 = emb_pair[0]['embeddings']\n",
    "        emb2 = emb_pair[1]['embeddings']\n",
    "        sys_scores.append(get_cosine_similarity(emb1, emb2))\n",
    "        \n",
    "    results = {'pearson': pearsonr(sys_scores, gs_scores),\n",
    "               'spearman': spearmanr(sys_scores, gs_scores),\n",
    "               'nsamples': len(sys_scores)}\n",
    "    return results"
   ]
  },
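  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop in `get_cosine_similarity` can also be written with NumPy. The sketch below is an optional alternative, not part of the original pipeline: it assumes each embedding is an equal-length list of floats (as in `prediction['predictions'][i]['embeddings']`), and the function name `cosine_similarity_np` is hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def cosine_similarity_np(vec1, vec2):\n",
    "    # Vectorized equivalent of get_cosine_similarity: dot(a, b) / (norm(a) * norm(b))\n",
    "    a = np.asarray(vec1, dtype=float)\n",
    "    b = np.asarray(vec2, dtype=float)\n",
    "    assert a.shape == b.shape, 'Vector dimension mismatch!'\n",
    "    return float(a.dot(b) / (np.linalg.norm(a) * np.linalg.norm(b)))\n",
    "\n",
    "\n",
    "# Orthogonal vectors give 0.0; parallel vectors give (up to rounding) 1.0\n",
    "print(cosine_similarity_np([1, 0], [0, 1]))\n",
    "print(cosine_similarity_np([1, 2, 3], [2, 4, 6]))"
   ]
  },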
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check overlap between SNLI and STS vocabulary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "snli_vocab_path = os.path.join(SNLI_PATH, 'snli-vocab.json')\n",
    "sts_vocab_path = os.path.join(STS_PATH, 'sts-vocab.json')\n",
    "\n",
    "with open(sts_vocab_path) as f:\n",
    "    sts_v = json.load(f)\n",
    "with open(snli_vocab_path) as f:\n",
    "    snli_v = json.load(f)\n",
    "\n",
    "sts_v_set = set(sts_v.keys())\n",
    "snli_v_set = set(snli_v.keys())\n",
    "\n",
    "print(f\"The STS vocabulary contains {len(sts_v_set)} words\")\n",
    "not_captured = sts_v_set.difference(snli_v_set)\n",
    "print(not_captured)\n",
    "print(f\"\\nThe number of words in STS not included in SNLI is {len(not_captured)}\")\n",
    "print(f\"\\nThis is {round(100 * len(not_captured) / len(sts_v_set), 2)} percent of the total STS vocabulary\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Since a large share of the STS vocabulary is not covered by SNLI, we add the uncovered words to our vocabulary and use pretrained *GloVe* embeddings to initialize our network. \n",
    "\n",
    "##### Intuitive reasoning for why this works\n",
    "\n",
    "* Our algorithm will not have seen the ***uncovered words*** during training\n",
    "* If we represent words directly as integer IDs, an unseen word's embedding is never updated during training, so it bears no learned relationship to the words that were seen. \n",
    "  - This means the model cannot embed the unseen words in a manner that takes advantage of its training knowledge\n",
    "* However, if we use pretrained word embeddings, we expect some of the unseen words to lie close, in the embedding space, to words the algorithm has seen "
   ]
  },
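  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate this intuition, here is a toy sketch. The 3-dimensional vectors below are made up (they are not real *GloVe* values, which have 300 dimensions), and `nearest_seen_word` is a hypothetical helper: given the pretrained vector of a word that never appears in training, it returns the closest word that does, by cosine similarity. With randomly initialized integer-id embeddings, an unseen word's vector would carry no such structure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Made-up 3-d 'pretrained' vectors for words seen during training\n",
    "seen_vectors = {\n",
    "    'car': np.array([0.9, 0.1, 0.0]),\n",
    "    'dog': np.array([0.0, 0.8, 0.3]),\n",
    "    'happy': np.array([0.1, 0.0, 0.9]),\n",
    "}\n",
    "# Made-up vector for a word that never appears in the training data\n",
    "unseen_word, unseen_vec = 'automobile', np.array([0.85, 0.15, 0.05])\n",
    "\n",
    "\n",
    "def nearest_seen_word(vec, vectors):\n",
    "    # Return the seen word whose vector has the highest cosine similarity to vec\n",
    "    def cos(a, b):\n",
    "        return float(a.dot(b) / (np.linalg.norm(a) * np.linalg.norm(b)))\n",
    "    return max(vectors, key=lambda w: cos(vec, vectors[w]))\n",
    "\n",
    "\n",
    "print(unseen_word, '->', nearest_seen_word(unseen_vec, seen_vectors))"
   ]
  },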
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def combine_vocabulary(vocab_paths, new_vocab_path):\n",
    "    # Union of the words in all vocabularies; the original ids are discarded\n",
    "    all_vocab = set()\n",
    "    for vocab_path in vocab_paths:\n",
    "        with open(vocab_path) as f:\n",
    "            vocab = json.load(f)\n",
    "            all_vocab = all_vocab.union(vocab.keys())\n",
    "    # Sort the words so that the word-to-id assignment is the same on every run\n",
    "    new_vocab = {wd: idx for idx, wd in enumerate(sorted(all_vocab))}\n",
    "    print(f\"The new vocabulary size is {len(new_vocab)}\")\n",
    "    with open(new_vocab_path, 'w') as f:\n",
    "        json.dump(new_vocab, f)\n",
    "        \n",
    "vocab_paths = [snli_vocab_path, sts_vocab_path]\n",
    "new_vocab_path = \"all_vocab.json\"\n",
    "\n",
    "combine_vocabulary(vocab_paths, new_vocab_path)"
   ]
  },
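  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The merge step can be checked in memory without touching the JSON files. The sketch below uses a hypothetical helper, `merge_vocabularies`, and two made-up toy vocabularies. It takes the union of the keys and assigns fresh consecutive ids, sorting the words so that the ids are identical on every run (iteration order over a Python `set` of strings can change between interpreter sessions because of string hash randomization)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def merge_vocabularies(vocabs):\n",
    "    # Union of all words across the vocabularies; the original ids are discarded\n",
    "    all_words = set()\n",
    "    for vocab in vocabs:\n",
    "        all_words.update(vocab.keys())\n",
    "    # Sorting gives a deterministic word -> id assignment across runs\n",
    "    return {word: idx for idx, word in enumerate(sorted(all_words))}\n",
    "\n",
    "\n",
    "# Made-up toy vocabularies standing in for snli-vocab.json and sts-vocab.json\n",
    "toy_snli = {'a': 0, 'dog': 1, 'runs': 2}\n",
    "toy_sts = {'a': 0, 'question': 1}\n",
    "merged = merge_vocabularies([toy_snli, toy_sts])\n",
    "print(merged)"
   ]
  },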
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get pre-trained GloVe word embedding and upload it to S3\n",
    "\n",
    "- Our notebook storage is not large enough to hold the *GloVe* file. Fortunately, the `/tmp` folder has extra space that we can use: https://docs.aws.amazon.com/sagemaker/latest/dg/howitworks-create-ws.html\n",
    "- The bash script below downloads and unzips *GloVe* into the `/tmp` folder; remember to remove it after use"
   ]
  },
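  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before running the download, you may want to confirm that `/tmp` actually has room for both the archive and the extracted text file, which together take several GB. This is a minimal sketch using Python's standard `shutil.disk_usage`; the helper `has_free_space` is hypothetical, and `REQUIRED_GB` is an assumed rough budget rather than an official figure, so adjust it as needed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "\n",
    "\n",
    "def has_free_space(path, required_gb):\n",
    "    # shutil.disk_usage returns a (total, used, free) named tuple, in bytes\n",
    "    free_gb = shutil.disk_usage(path).free / 1024 ** 3\n",
    "    print(f'{path}: {free_gb:.1f} GB free, {required_gb} GB required')\n",
    "    return free_gb >= required_gb\n",
    "\n",
    "\n",
    "# Assumption: rough budget for the zip plus the extracted GloVe text file; adjust as needed\n",
    "REQUIRED_GB = 8\n",
    "has_free_space('/tmp', REQUIRED_GB)"
   ]
  },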
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "# download glove file from website\n",
    "mkdir -p /tmp/GloVe\n",
    "curl -Lo /tmp/GloVe/glove.840B.zip http://nlp.stanford.edu/data/glove.840B.300d.zip\n",
    "unzip /tmp/GloVe/glove.840B.zip -d /tmp/GloVe/\n",
    "rm /tmp/GloVe/glove.840B.zip"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we trim the original *GloVe* embedding file so that it covers only our combined vocabulary, and save the trimmed file in a newly created local `GloVe` directory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p GloVe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# credit: This preprocessing function is modified from the w2v preprocessing script in Facebook infersent codebase\n",
    "# Infersent code license can be found at: https://github.com/facebookresearch/InferSent/blob/master/LICENSE\n",
    "\n",
    "def trim_w2v(in_path, out_path, word_dict):\n",
    "    # keep only the lines of the embedding file whose word is in word_dict\n",
    "    lines = []\n",
    "    # GloVe contains non-ASCII tokens, so read and write explicitly as UTF-8\n",
    "    with open(in_path, encoding='utf8') as f:\n",
    "        for line in f:\n",
    "            word, _ = line.split(' ', 1)\n",
    "            if word in word_dict:\n",
    "                lines.append(line)\n",
    "\n",
    "    print('Found %s(/%s) words with w2v vectors' % (len(lines), len(word_dict)))\n",
    "    with open(out_path, 'w', encoding='utf8') as outfile:\n",
    "        outfile.writelines(lines)\n",
    "\n",
    "in_path = '/tmp/GloVe/glove.840B.300d.txt'\n",
    "out_path = 'GloVe/glove.840B-trim.txt'\n",
    "with open('all_vocab.json') as f:\n",
    "    word_dict = json.load(f)\n",
    "\n",
    "trim_w2v(in_path, out_path, word_dict)"
   ]
  },
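  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the trimmed file contains, the sketch below reads it back into an embedding matrix whose row `i` holds the vector of the word with id `i` in `all_vocab.json`. This is only an inspection aid and is not required by the rest of the notebook; the helper `load_embedding_matrix`, the all-zero default for words without a *GloVe* vector, and the tiny temporary demo file are assumptions for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def load_embedding_matrix(glove_path, word_dict, dim=300):\n",
    "    # Row i holds the vector of the word whose id is i; words without a vector stay all-zero\n",
    "    matrix = np.zeros((len(word_dict), dim), dtype=np.float32)\n",
    "    found = 0\n",
    "    with open(glove_path, encoding='utf8') as f:\n",
    "        for line in f:\n",
    "            parts = line.rstrip().split(' ')\n",
    "            # Take the last `dim` fields as the vector, in case a token itself contains spaces\n",
    "            word, values = ' '.join(parts[:-dim]), parts[-dim:]\n",
    "            if word in word_dict:\n",
    "                matrix[word_dict[word]] = np.asarray(values, dtype=np.float32)\n",
    "                found += 1\n",
    "    print(f'Filled {found}/{len(word_dict)} rows')\n",
    "    return matrix\n",
    "\n",
    "\n",
    "# Demo on a tiny made-up 2-d file instead of the real 300-d GloVe file\n",
    "demo_dict = {'cat': 0, 'dog': 1, 'zebra': 2}\n",
    "with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False, encoding='utf8') as tmp:\n",
    "    print('cat 0.1 0.2', file=tmp)\n",
    "    print('dog 0.3 0.4', file=tmp)\n",
    "demo_matrix = load_embedding_matrix(tmp.name, demo_dict, dim=2)\n",
    "os.remove(tmp.name)\n",
    "print(demo_matrix)"
   ]
  },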
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# remember to remove the original GloVe embedding folder since it takes up a lot of space\n",
    "!rm -r /tmp/GloVe/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reprocess training data (SNLI) with the combined vocabulary"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a new directory called `all_vocab_datasets`, and copy the raw SNLI `.jsonl` files and `all_vocab.json` into it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "\n",
    "mkdir -p all_vocab_datasets\n",
    "\n",
    "for SPLIT in train dev test\n",
    "do\n",
    "    cp snli_1.0/snli_1.0_${SPLIT}.jsonl all_vocab_datasets/\n",
    "done\n",
    "\n",
    "cp all_vocab.json all_vocab_datasets/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Convert the SNLI data to integers using `all_vocab.json`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "make_snli_data(dirname=\"all_vocab_datasets\", vocab_file='all_vocab.json', outfile_suffix='pretrain', force=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reset training environment"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that because we combined the vocabularies of our training and test data, we should not fine-tune the *GloVe* embeddings but keep them fixed; otherwise, it amounts to a bit of cheating: training on test data! Thankfully, the hyperparameters `enc0_freeze_pretrained_embedding` and `enc1_freeze_pretrained_embedding` are set to `True` by default. In the earlier training, where we did not use pretrained embeddings, these parameters have no effect."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hyperparameters_2, input_channels_2 = set_training_environment(bucket, prefix, \n",
    "                                                               is_quick_run=True, \n",
    "                                                               is_pretrain=True, \n",
    "                                                               use_all_vocab=True)\n",
    "\n",
    "# attach a new estimator to the previous training job, identified by its job name\n",
    "# (this also retrieves the log of the previous training job)\n",
    "training_job_name = regressor.latest_training_job.name\n",
    "new_regressor = regressor.attach(training_job_name, sagemaker_session=sess)\n",
    "new_regressor.set_hyperparameters(**hyperparameters_2)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fit the new regressor using the new data (with pretrained embedding)\n",
    "new_regressor.fit(input_channels_2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Deploy and test the new model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor_2 = new_regressor.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')\n",
    "predictor_2.content_type = 'application/json'\n",
    "predictor_2.serializer = json_serializer\n",
    "predictor_2.deserializer = json_deserializer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We first check the test accuracy on SNLI after adding the pretrained embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load SNLI test data\n",
    "snli_test_path = os.path.join(\"all_vocab_datasets\", 'snli-integer-test-pretrain.jsonl')\n",
    "test_data_content = list()\n",
    "test_label = list()\n",
    "\n",
    "for line in read_jsonline(snli_test_path):\n",
    "    test_data_content.append({'in0':line['in0'], 'in1':line['in1']})\n",
    "    test_label.append({'label': line['label']})\n",
    "\n",
    "print(\"Evaluating test results on SNLI with pre-trained embedding...\")\n",
    "\n",
    "batch_size = 100\n",
    "n_test = len(test_label)\n",
    "n_batches = math.ceil(n_test / float(batch_size))\n",
    "start = 0\n",
    "agg_acc = 0\n",
    "for idx in range(n_batches):\n",
    "    if idx % 10 == 0:\n",
    "        print(f\"Evaluating the {idx+1}-th batch\")\n",
    "    end = min(start + batch_size, n_test)\n",
    "    payload = {'instances': test_data_content[start:end]}\n",
    "    acc = calc_prediction_accuracy(predictor_2.predict(payload), test_label[start:end])\n",
    "    # weight each batch's accuracy by its size, which is end - start\n",
    "    agg_acc += acc * (end - start)\n",
    "    start = end\n",
    "print(f\"The test accuracy is {agg_acc/n_test}\")"
   ]
  },
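  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above computes the overall SNLI test accuracy as an average of per-batch accuracies, weighted by the number of examples in each batch. Since a slice `[start:end]` holds exactly `end - start` items, that is the weight each batch should receive. A minimal sketch of this bookkeeping follows; the helpers `batch_bounds` and `weighted_accuracy` and the batch accuracies are made up for illustration and do not call the endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def batch_bounds(n_test, batch_size):\n",
    "    # Yields (start, end) slice bounds; the last batch may be smaller\n",
    "    for start in range(0, n_test, batch_size):\n",
    "        yield start, min(start + batch_size, n_test)\n",
    "\n",
    "\n",
    "def weighted_accuracy(batch_results):\n",
    "    # batch_results: list of (accuracy, n_examples) pairs, one per batch\n",
    "    total = sum(n for _, n in batch_results)\n",
    "    return sum(acc * n for acc, n in batch_results) / total\n",
    "\n",
    "\n",
    "# Made-up example: 250 test examples in batches of 100 give sizes 100, 100, 50\n",
    "bounds = list(batch_bounds(250, 100))\n",
    "print(bounds)\n",
    "print(weighted_accuracy([(0.8, 100), (0.9, 100), (0.6, 50)]))"
   ]
  },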
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "Next, we test the zero-shot transfer learning performance of our trained model on the STS task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_data_qq = wrap_sts_test_data_for_eval(fpath=STS_PATH, vocab_path_prefix=\"all_vocab_datasets\", \n",
    "                       vocab_name='all_vocab.json', dataset='question-question')\n",
    "\n",
    "results = eval_corr(predictor_2, eval_data_qq)\n",
    "\n",
    "pcorr = results['pearson'][0]\n",
    "spcorr = results['spearman'][0]\n",
    "print(f\"The Pearson correlation to gold standard labels is {pcorr}\")\n",
    "print(f\"The Spearman correlation to gold standard labels is {spcorr}\")"
   ]
  },
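  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`eval_corr` reports both Pearson and Spearman correlation. As a quick reminder of the difference, the toy sketch below uses `scipy.stats` on made-up scores (not STS results): Spearman correlation depends only on the ranks, so a strictly increasing but nonlinear system score still reaches 1, while Pearson correlation measures linear agreement and falls below 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.stats import pearsonr, spearmanr\n",
    "\n",
    "# Made-up gold scores on a 0-5 scale and a system score that is monotone but nonlinear in them\n",
    "gold = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])\n",
    "system = np.exp(gold)\n",
    "\n",
    "toy_pearson = pearsonr(system, gold)[0]\n",
    "toy_spearman = spearmanr(system, gold)[0]\n",
    "print(f'Pearson: {toy_pearson:.3f}, Spearman: {toy_spearman:.3f}')"
   ]
  },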
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## clean up\n",
    "sess.delete_endpoint(predictor1.endpoint)\n",
    "sess.delete_endpoint(predictor_2.endpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How to enable the optimal training result "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So far we have been training the algorithm with `is_quick_run` set to `True` (in the `set_training_environment` function) to minimize the time it takes to run through this notebook. To get the best performance of *Object2Vec* on the tasks above, we recommend setting `is_quick_run` to `False`. For example, with pretrained embeddings, we would replace the code block under **Reset training environment** with the block below"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color:red\">Run with caution</span>: \n",
    "This may take a few hours to complete depending on the machine instance you are using"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hyperparameters_2, input_channels_2 = set_training_environment(bucket, prefix, \n",
    "                                                               is_quick_run=False, # modify is_quick_run flag here\n",
    "                                                               is_pretrain=True, \n",
    "                                                               use_all_vocab=True)\n",
    "\n",
    "training_job_name = regressor.latest_training_job.name\n",
    "new_regressor = regressor.attach(training_job_name, sagemaker_session=sess)\n",
    "new_regressor.set_hyperparameters(**hyperparameters_2)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we can train and deploy the model as before. Similarly, without pretrained embeddings, the code block under **Train without using pretrained embedding** can be replaced with the one below to optimize the training result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color:red\">Run with caution</span>: \n",
    "This may take a few hours to complete depending on the machine instance you are using"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hyperparameters, input_channels = set_training_environment(bucket, prefix, \n",
    "                                                           is_quick_run=False, # modify is_quick_run flag here\n",
    "                                                           is_pretrain=False, \n",
    "                                                           use_all_vocab=False)\n",
    "\n",
    "regressor.set_hyperparameters(**hyperparameters)\n",
    "regressor.hyperparameters()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Best training result\n",
    "\n",
    "With `is_quick_run=False` and without pretrained embeddings, our algorithm's test accuracy on the SNLI dataset is 78.5%; with pretrained *GloVe* embeddings, the test accuracy on SNLI improves to 81.9%. On the STS data, you should expect a Pearson correlation of around 0.61.\n",
    "\n",
    "In addition to the training demonstrated in this notebook, we have also run benchmarking experiments on both SNLI and STS data with different hyperparameter configurations; the results are shown below.\n",
    "\n",
    "In both charts, we compare against Facebook's InferSent algorithm (https://research.fb.com/downloads/infersent/). The chart on the left shows additional experiment results on SNLI (using CNN or RNN encoders). The chart on the right shows the best experiment result of Object2Vec on STS."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img style=\"float:left\" src=\"o2v-exp-snli.png\" width=\"430\"> \n",
    "<img style=\"float:left\" src=\"o2v-exp-sts.png\" width=\"430\">"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hyperparameter Tuning (Advanced) \n",
    "with the Hyperparameter Optimization (HPO) service in SageMaker"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Getting optimal performance out of any machine learning algorithm often requires substantial parameter tuning. \n",
    "In this notebook demo, we have hidden the hard work of finding a good combination of parameters for the algorithm on SNLI data (again, the optimal parameters are used only when you run the `set_training_environment` method with `is_quick_run=False`).\n",
    "\n",
    "If you want to explore tuning hyperparameters (HPs) on your own, you may find the code blocks below helpful."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To find the best HP combinations for our task, we can tune parameters by launching HPO jobs. \n",
    "- As a simple example, we demonstrate here how to find the best `enc_dim` parameter using the HPO service"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s3_uri_path = {}\n",
    "\n",
    "for split in ['train', 'validation']:\n",
    "    s3_uri_path[split] = input_path + f'{split}/snli-integer-{split}.jsonl'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At a high level, an HPO tuning job is a collection of training jobs with different HP settings; the SageMaker HPO service compares the performance of these training jobs according to the **HPO tuning metric**, which is specified in `tuning_job_config`.\n",
    "\n",
    "- More information on how to launch HPO tuning jobs manually can be found here: \n",
    "https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-ex-tuning-job.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tuning_job_config = {\n",
    "    \"ParameterRanges\": {\n",
    "      \"CategoricalParameterRanges\": [\n",
    "      ],\n",
    "      \"ContinuousParameterRanges\": [],\n",
    "      \"IntegerParameterRanges\": [ \n",
    "        {\n",
    "            \"MaxValue\": \"1024\",\n",
    "            \"MinValue\": \"16\",\n",
    "            \"Name\": \"enc_dim\"\n",
    "        }\n",
    "      ],\n",
    "    },\n",
    "    \"ResourceLimits\": {\n",
    "      \"MaxNumberOfTrainingJobs\": 3,\n",
    "      \"MaxParallelTrainingJobs\": 3\n",
    "    },\n",
    "    \"Strategy\": \"Bayesian\",\n",
    "    \"HyperParameterTuningJobObjective\": {\n",
    "      \"MetricName\": \"validation:accuracy\",\n",
    "      \"Type\": \"Maximize\"\n",
    "    }\n",
    "  }"
   ]
  },
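  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `enc_dim` spans almost two orders of magnitude (16 to 1024), you may prefer to search it on a logarithmic scale. The sketch below is a hypothetical helper, `integer_range`, that builds one entry for `IntegerParameterRanges`. It assumes the optional `ScalingType` field of the SageMaker `IntegerParameterRange` structure (values such as `Auto`, `Linear`, `Logarithmic`); please check the API reference before relying on it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def integer_range(name, min_value, max_value, scaling='Auto'):\n",
    "    # Builds one IntegerParameterRanges entry; the API expects the bounds as strings\n",
    "    if min_value > max_value:\n",
    "        raise ValueError('min_value must not exceed max_value')\n",
    "    if scaling == 'Logarithmic' and min_value <= 0:\n",
    "        raise ValueError('Logarithmic scaling needs strictly positive bounds')\n",
    "    return {\n",
    "        'Name': name,\n",
    "        'MinValue': str(min_value),\n",
    "        'MaxValue': str(max_value),\n",
    "        'ScalingType': scaling,\n",
    "    }\n",
    "\n",
    "\n",
    "enc_dim_range = integer_range('enc_dim', 16, 1024, scaling='Logarithmic')\n",
    "print(enc_dim_range)\n",
    "# To use it, you could replace the range defined above:\n",
    "# tuning_job_config['ParameterRanges']['IntegerParameterRanges'] = [enc_dim_range]"
   ]
  },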
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The tuning metric (`MetricName`) we use here is `validation:accuracy`, with `Type` set to `Maximize`, since we want to maximize accuracy. (If you want to minimize mean squared error instead, switch the tuning objective to `validation:mean_squared_error` and `Minimize`.)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The configuration of an individual training job within an HPO job is defined as follows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_job_definition = {\n",
    "    \"AlgorithmSpecification\": {\n",
    "      \"TrainingImage\": container,\n",
    "      \"TrainingInputMode\": \"File\"\n",
    "    },\n",
    "    \"InputDataConfig\": [\n",
    "      {\n",
    "        \"ChannelName\": \"train\",\n",
    "        \"CompressionType\": \"None\",\n",
    "        \"ContentType\": \"application/jsonlines\",\n",
    "        \"DataSource\": {\n",
    "          \"S3DataSource\": {\n",
    "            \"S3DataDistributionType\": \"FullyReplicated\",\n",
    "            \"S3DataType\": \"S3Prefix\",\n",
    "            \"S3Uri\": s3_uri_path['train']\n",
    "          }\n",
    "        }\n",
    "      },\n",
    "      {\n",
    "        \"ChannelName\": \"validation\",\n",
    "        \"CompressionType\": \"None\",\n",
    "        \"ContentType\": \"application/jsonlines\",\n",
    "        \"DataSource\": {\n",
    "          \"S3DataSource\": {\n",
    "            \"S3DataDistributionType\": \"FullyReplicated\",\n",
    "            \"S3DataType\": \"S3Prefix\",\n",
    "            \"S3Uri\": s3_uri_path['validation']\n",
    "          }\n",
    "        }\n",
    "      }\n",
    "    ],\n",
    "    \"OutputDataConfig\": {\n",
    "      \"S3OutputPath\": output_path\n",
    "    },\n",
    "    \"ResourceConfig\": {\n",
    "      \"InstanceCount\": 1,\n",
    "      \"InstanceType\": \"ml.p2.8xlarge\",\n",
    "      \"VolumeSizeInGB\": 20\n",
    "    },\n",
    "    \"RoleArn\": role,\n",
    "    \"StaticHyperParameters\": {\n",
    "             #'enc_dim': \"16\",  # do not include enc_dim here as static HP since we are tuning it\n",
    "             'learning_rate': '0.0004', \n",
    "             'mlp_dim': \"512\",\n",
    "             'mlp_activation': 'linear',\n",
    "             'mlp_layers': '2',\n",
    "             'output_layer': 'softmax',\n",
    "             'optimizer': 'adam',\n",
    "             'mini_batch_size': '8192',\n",
    "             'epochs': '2',\n",
    "             'bucket_width': '0',\n",
    "             'early_stopping_tolerance': '0.01',\n",
    "             'early_stopping_patience': '3',\n",
    "             'dropout': '0',\n",
    "             'weight_decay': '0',\n",
    "             'enc0_max_seq_len': '82',\n",
    "             'enc1_max_seq_len': '82',\n",
    "             'enc0_network': 'hcnn',\n",
    "             'enc1_network': 'enc0',\n",
    "             'enc0_token_embedding_dim': '300',\n",
    "             'enc0_layers': 'auto',\n",
    "             'enc0_cnn_filter_width': '3',\n",
    "             'enc1_token_embedding_dim': '300',\n",
    "             'enc1_layers': 'auto',\n",
    "             'enc1_cnn_filter_width': '3',\n",
    "             'enc0_vocab_file': '',\n",
    "             'enc1_vocab_file': '',\n",
    "             'enc0_vocab_size': '43533',\n",
    "             'enc1_vocab_size': '43533',\n",
    "             'num_classes': '3',\n",
    "             '_num_gpus': 'auto',\n",
    "             '_num_kv_servers': 'auto',\n",
    "             '_kvstore': 'device'},\n",
    "    \"StoppingCondition\": {\n",
    "      \"MaxRuntimeInSeconds\": 43200\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "sm_client = boto3.Session().client('sagemaker')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Disclaimer\n",
    "\n",
    "Running HPO tuning jobs means dispatching multiple training jobs with different HP settings; this can incur significant costs on your AWS account if you use HP combinations that take many hours to train.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tuning_job_name = \"hpo-o2v-test\"\n",
    "response = sm_client.create_hyper_parameter_tuning_job(HyperParameterTuningJobName = tuning_job_name,\n",
    "                                           HyperParameterTuningJobConfig = tuning_job_config,\n",
    "                                           TrainingJobDefinition = training_job_definition)"
   ]
  },
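  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Besides the console, you can check on the tuning job from the notebook with the boto3 SageMaker client's `describe_hyper_parameter_tuning_job` method. The sketch below assumes the response fields `HyperParameterTuningJobStatus` and `BestTrainingJob` (the latter is present only once a training job has finished), as described in the SageMaker API reference. The helper `summarize_tuning_job` and the fake response used in the demo are hypothetical, so running the cell makes no AWS call."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def summarize_tuning_job(desc):\n",
    "    # desc: response dict of sm_client.describe_hyper_parameter_tuning_job(...)\n",
    "    summary = {'status': desc['HyperParameterTuningJobStatus']}\n",
    "    best = desc.get('BestTrainingJob')\n",
    "    if best:  # absent until at least one training job has completed\n",
    "        metric = best.get('FinalHyperParameterTuningJobObjectiveMetric', {})\n",
    "        summary['best_job'] = best['TrainingJobName']\n",
    "        summary['best_value'] = metric.get('Value')\n",
    "        summary['best_hps'] = best.get('TunedHyperParameters')\n",
    "    return summary\n",
    "\n",
    "\n",
    "# Real usage (makes an AWS API call):\n",
    "# desc = sm_client.describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=tuning_job_name)\n",
    "# print(summarize_tuning_job(desc))\n",
    "\n",
    "# Offline demo with a made-up response\n",
    "fake_desc = {\n",
    "    'HyperParameterTuningJobStatus': 'InProgress',\n",
    "    'BestTrainingJob': {\n",
    "        'TrainingJobName': 'hpo-o2v-test-001-abcdef12',\n",
    "        'TunedHyperParameters': {'enc_dim': '512'},\n",
    "        'FinalHyperParameterTuningJobObjectiveMetric': {'MetricName': 'validation:accuracy', 'Value': 0.75},\n",
    "    },\n",
    "}\n",
    "print(summarize_tuning_job(fake_desc))"
   ]
  },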
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can then view and track the hyperparameter tuning jobs you launched in the SageMaker console (using the same account that you used to create the SageMaker client that launched these jobs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
